{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data import sampler\n",
    "import torchvision.datasets as dset\n",
    "import torchvision.transforms as T\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "NUM_TRAIN = 49000\n",
    "transform = T.Compose([T.ToTensor(), T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])\n",
    "\n",
    "cifar10_train = dset.CIFAR10('./cs231n/datasets/', train=True, download=True, transform=transform)\n",
    "loader_train = DataLoader(cifar10_train, batch_size=64, sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))\n",
    "\n",
    "cifar10_val = dset.CIFAR10('./cs231n/datasets/', train=True, download=True, transform=transform)\n",
    "loader_val = DataLoader(cifar10_val, batch_size=64, sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, 50000)))\n",
    "\n",
    "cifar10_test = dset.CIFAR10('./cs231n/datasets/', train=False, download=True, transform=transform)\n",
    "loader_test = DataLoader(cifar10_test, batch_size=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "using device: cpu\n"
     ]
    }
   ],
   "source": [
    "USE_GPU = True\n",
    "dtype = torch.float32\n",
    "print_every = 100\n",
    "\n",
    "if USE_GPU and torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "    \n",
    "print('using device:', device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Before flattening:  torch.Size([2, 1, 3, 2])\n",
      "After flattening:  torch.Size([2, 6])\n"
     ]
    }
   ],
   "source": [
    "def flatten(x):\n",
    "    N = x.shape[0]\n",
    "    return x.view(N, -1)\n",
    "\n",
    "def test_flatten():\n",
    "    x = torch.arange(12).view(2, 1, 3, 2)\n",
    "    print('Before flattening: ', x.shape)\n",
    "    print('After flattening: ', flatten(x).shape)\n",
    "    \n",
    "test_flatten()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Barebones PyTorch\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Two-Layer Network\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([64, 10])\n"
     ]
    }
   ],
   "source": [
    "def two_layer_fc(x, params):\n",
    "    x = flatten(x)\n",
    "    w1, w2 = params\n",
    "    \n",
    "    x = F.relu(x.mm(w1))\n",
    "    x = x.mm(w2)\n",
    "    \n",
    "    return x\n",
    "\n",
    "def two_layer_fc_test():\n",
    "    hidden_layer_size = 42\n",
    "    x = torch.zeros((64, 50), dtype=dtype)\n",
    "    w1 = torch.zeros((50, hidden_layer_size), dtype=dtype)\n",
    "    w2 = torch.zeros((hidden_layer_size, 10), dtype=dtype)\n",
    "    params = (w1, w2)\n",
    "    scores = two_layer_fc(x, params)\n",
    "    print(scores.shape)\n",
    "    \n",
    "two_layer_fc_test()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Three-Layer ConvNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def three_layer_convnet(x, params):\n",
    "    w1, b1, w2, b2, w3, b3 = params\n",
    "    scores = None\n",
    "    \n",
    "    out1 = F.conv2d(x, w1, bias=b1, stride=1, padding=(2, 2))\n",
    "    relu1 = F.relu(out1)\n",
    "    out2 = F.conv2d(relu1, w2, bias=b2, stride=1, padding=(1, 1))\n",
    "    relu2 = F.relu(out2)\n",
    "    scores = torch.mm(flatten(relu2), w3) + b3\n",
    "    \n",
    "    return scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([64, 10])\n"
     ]
    }
   ],
   "source": [
    "def three_layer_convnet_test():\n",
    "    x = torch.zeros((64, 3, 32, 32), dtype=dtype)\n",
    "    \n",
    "    w1 = torch.zeros((6, 3, 5, 5), dtype=dtype)\n",
    "    b1 = torch.zeros((6, ), dtype=dtype)\n",
    "    w2 = torch.zeros((9, 6, 3, 3), dtype=dtype)\n",
    "    b2 = torch.zeros((9, ), dtype=dtype)\n",
    "    w3 = torch.zeros((32*32*9, 10), dtype=dtype)\n",
    "    b3 = torch.zeros((10, ), dtype=dtype)\n",
    "    \n",
    "    params = (w1, b1, w2, b2, w3, b3)\n",
    "    scores = three_layer_convnet(x, params)\n",
    "    print(scores.shape)\n",
    "    \n",
    "three_layer_convnet_test()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 137,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.5531,  1.6935,  0.9294,  0.5106, -1.2486],\n",
       "        [-0.2500,  1.1149, -0.2589, -0.5693,  0.5004],\n",
       "        [-1.2586, -0.0858, -0.2165,  0.1567,  1.4835]])"
      ]
     },
     "execution_count": 137,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def random_weight(shape):\n",
    "    if len(shape) == 2:\n",
    "        mode = 'fan_out'\n",
    "    else:\n",
    "        mode = 'fan_in'\n",
    "    weight = torch.empty(shape)\n",
    "    nn.init.kaiming_normal_(weight, mode=mode)\n",
    "    weight.requires_grad = True\n",
    "    return weight\n",
    "\n",
    "def zero_weight(shape):\n",
    "    return torch.zeros(shape, device=device, dtype=dtype, requires_grad=True)\n",
    "\n",
    "# def random_weight(shape):\n",
    "#     if len(shape) == 2:\n",
    "#         fan_in = shape[0]\n",
    "#     else:\n",
    "#         fan_in = np.prod(shape[1:])\n",
    "#     w = torch.randn(shape, device=device, dtype=dtype) * np.sqrt(2. / fan_in)\n",
    "#     w.requires_grad = True\n",
    "#     return w\n",
    "\n",
    "random_weight((3, 5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 138,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_accuracy(loader, model_fn, params):\n",
    "    split = 'val' if loader.dataset.train else 'test'\n",
    "    print('Checking accuracy on the %s set' % split)\n",
    "    num_correct, num_samples = 0, 0\n",
    "    with torch.no_grad():\n",
    "        for x, y in loader:\n",
    "            x = x.to(device=device, dtype=dtype)\n",
    "            y = y.to(device=device, dtype=torch.long)\n",
    "            scores = model_fn(x, params)\n",
    "            max_scores, preds = torch.max(scores, dim=1)\n",
    "            num_correct += (preds == y).sum()\n",
    "            num_samples += x.shape[0]\n",
    "\n",
    "        acc = float(num_correct) / num_samples\n",
    "        print('Got %d / %d correct (%.2f%%)' % (num_correct, num_samples, 100 * acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model_fn, params, learning_rate):\n",
    "    for t, (x, y) in enumerate(loader_train):\n",
    "        x = x.to(device=device, dtype=dtype)\n",
    "        y = y.to(device=device, dtype=torch.long)\n",
    "        \n",
    "        scores = model_fn(x, params)\n",
    "        loss = F.cross_entropy(scores, y)\n",
    "        \n",
    "        loss.backward()\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            for w in params:\n",
    "                w -= learning_rate * w.grad\n",
    "                \n",
    "                w.grad.zero_()\n",
    "                \n",
    "        if t % print_every == 0:\n",
    "            print('Iteration %d, loss = %.4f' % (t, loss.item()))\n",
    "            check_accuracy(loader_val, model_fn, params)\n",
    "            print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train a Two-Layer Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 140,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration 0, loss = 4.0419\n",
      "Checking accuracy on the val set\n",
      "Got 144 / 1000 correct (14.40%)\n",
      "\n",
      "Iteration 100, loss = 1.9465\n",
      "Checking accuracy on the val set\n",
      "Got 343 / 1000 correct (34.30%)\n",
      "\n",
      "Iteration 200, loss = 2.5488\n",
      "Checking accuracy on the val set\n",
      "Got 362 / 1000 correct (36.20%)\n",
      "\n",
      "Iteration 300, loss = 2.1277\n",
      "Checking accuracy on the val set\n",
      "Got 365 / 1000 correct (36.50%)\n",
      "\n",
      "Iteration 400, loss = 2.3016\n",
      "Checking accuracy on the val set\n",
      "Got 395 / 1000 correct (39.50%)\n",
      "\n",
      "Iteration 500, loss = 1.6212\n",
      "Checking accuracy on the val set\n",
      "Got 426 / 1000 correct (42.60%)\n",
      "\n",
      "Iteration 600, loss = 1.6771\n",
      "Checking accuracy on the val set\n",
      "Got 421 / 1000 correct (42.10%)\n",
      "\n",
      "Iteration 700, loss = 1.7955\n",
      "Checking accuracy on the val set\n",
      "Got 430 / 1000 correct (43.00%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "hidden_layer_size = 4000\n",
    "learning_rate = 1e-2\n",
    "\n",
    "w1 = random_weight((3 * 32 * 32, hidden_layer_size))\n",
    "w2 = random_weight((hidden_layer_size, 10))\n",
    "\n",
    "train(two_layer_fc, [w1, w2], learning_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training a ConvNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration 0, loss = 3.2404\n",
      "Checking accuracy on the val set\n",
      "Got 103 / 1000 correct (10.30%)\n",
      "\n",
      "Iteration 100, loss = 2.0333\n",
      "Checking accuracy on the val set\n",
      "Got 333 / 1000 correct (33.30%)\n",
      "\n",
      "Iteration 200, loss = 1.8083\n",
      "Checking accuracy on the val set\n",
      "Got 383 / 1000 correct (38.30%)\n",
      "\n",
      "Iteration 300, loss = 1.6545\n",
      "Checking accuracy on the val set\n",
      "Got 418 / 1000 correct (41.80%)\n",
      "\n",
      "Iteration 400, loss = 1.7084\n",
      "Checking accuracy on the val set\n",
      "Got 450 / 1000 correct (45.00%)\n",
      "\n",
      "Iteration 500, loss = 1.4164\n",
      "Checking accuracy on the val set\n",
      "Got 462 / 1000 correct (46.20%)\n",
      "\n",
      "Iteration 600, loss = 1.3614\n",
      "Checking accuracy on the val set\n",
      "Got 460 / 1000 correct (46.00%)\n",
      "\n",
      "Iteration 700, loss = 1.5666\n",
      "Checking accuracy on the val set\n",
      "Got 455 / 1000 correct (45.50%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "learning_rate = 3e-3\n",
    "channel_1 = 32\n",
    "channel_2 = 16\n",
    "\n",
    "w1 = random_weight((channel_1, 3, 5, 5))\n",
    "b1 = zero_weight((channel_1, ))\n",
    "w2 = random_weight((channel_2, channel_1, 3, 3))\n",
    "b2 = zero_weight((channel_2, ))\n",
    "w3 = random_weight((32*32*channel_2, 10))\n",
    "b3 = zero_weight((10, ))\n",
    "\n",
    "params = (w1, b1, w2, b2, w3, b3)\n",
    "train(three_layer_convnet, params, learning_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PyTorch Module API\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Two-Layer Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 147,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([64, 10])\n"
     ]
    }
   ],
   "source": [
    "class TwoLayerFC(nn.Module):\n",
    "    def __init__(self, input_size, hidden_size, num_classes):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(input_size, hidden_size, bias=True)\n",
    "        nn.init.kaiming_normal_(self.fc1.weight)\n",
    "        self.fc2 = nn.Linear(hidden_size, num_classes, bias=True)\n",
    "        nn.init.kaiming_normal_(self.fc2.weight)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = flatten(x)\n",
    "        scores = self.fc2(F.relu(self.fc1(x)))\n",
    "        return scores\n",
    "\n",
    "def test_TwoLayerFC():\n",
    "    input_size = 50\n",
    "    x = torch.zeros((64, input_size), dtype=dtype)\n",
    "    model = TwoLayerFC(input_size, 42, 10)\n",
    "    scores = model.forward(x)\n",
    "    print(scores.shape)\n",
    "\n",
    "test_TwoLayerFC()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Three-Layer ConvNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 154,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([64, 10])\n"
     ]
    }
   ],
   "source": [
    "class ThreeLayerConvNet(nn.Module):\n",
    "    def __init__(self, params):\n",
    "        super().__init__()\n",
    "        in_channel = params.get('in_channel')\n",
    "        channel_1 = params.get('channel_1')\n",
    "        channel_2 = params.get('channel_2')\n",
    "        num_classes = params.get('num_classes')\n",
    "\n",
    "        self.conv1 = nn.Conv2d(in_channel, channel_1, kernel_size=(5, 5), stride=1, padding=(2, 2))\n",
    "        nn.init.kaiming_normal_(self.conv1.weight)\n",
    "        self.conv2 = nn.Conv2d(channel_1, channel_2, kernel_size=(3, 3), stride=1, padding=(1, 1))\n",
    "        nn.init.kaiming_normal_(self.conv2.weight)\n",
    "        self.fc = nn.Linear(channel_2*32*32, num_classes)\n",
    "        nn.init.kaiming_normal_(self.fc.weight)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out_1 = F.relu(self.conv1(x))\n",
    "        out_2 = F.relu(self.conv2(out_1))\n",
    "        scores = self.fc(flatten(out_2))\n",
    "        return scores\n",
    "\n",
    "def test_ThreeLayerConvNet():\n",
    "    x = torch.zeros((64, 3, 32, 32), dtype=dtype)\n",
    "    kwargs = {'in_channel': 3, 'channel_1': 12, 'channel_2': 8, 'num_classes':10}\n",
    "    model = ThreeLayerConvNet(kwargs)\n",
    "    scores = model.forward(x)\n",
    "    print(scores.shape)\n",
    "\n",
    "test_ThreeLayerConvNet()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_accuracy_2(loader, model):\n",
    "    if loader.dataset.train:\n",
    "        print('Checking accuracy on validation set')\n",
    "    else:\n",
    "        print('Checking accuracy on test set')\n",
    "        \n",
    "    num_samples = 0\n",
    "    num_correct = 0\n",
    "    model.eval()\n",
    "    for x, y in loader:\n",
    "        x = x.to(device=device, dtype=dtype)\n",
    "        y = y.to(device=device, dtype=torch.long)\n",
    "        scores = model(x)\n",
    "        _, preds = torch.max(scores, dim=1)\n",
    "        num_correct += (preds == y).sum()\n",
    "        num_samples += x.shape[0]\n",
    "    acc = float(num_correct) / num_samples\n",
    "    print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 158,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train2(model, optimizer, epochs=1):\n",
    "    model = model.to(device=device)\n",
    "    for e in range(epochs):\n",
    "        for t, (x, y) in enumerate(loader_train):\n",
    "            model.train()\n",
    "            x = x.to(device=device, dtype=dtype)\n",
    "            y = y.to(device=device, dtype=torch.long)\n",
    "            \n",
    "            scores = model(x)\n",
    "            loss = F.cross_entropy(scores, y)\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            if t % print_every == 0:\n",
    "                print('Iteration %d, loss = %.4f' % (t, loss.item()))\n",
    "                check_accuracy_2(loader_val, model)\n",
    "                print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train a Two-Layer Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 159,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration 0, loss = 2.8094\n",
      "Checking accuracy on validation set\n",
      "Got 161 / 1000 correct (16.10)\n",
      "\n",
      "Iteration 100, loss = 2.2876\n",
      "Checking accuracy on validation set\n",
      "Got 335 / 1000 correct (33.50)\n",
      "\n",
      "Iteration 200, loss = 1.7721\n",
      "Checking accuracy on validation set\n",
      "Got 415 / 1000 correct (41.50)\n",
      "\n",
      "Iteration 300, loss = 1.9384\n",
      "Checking accuracy on validation set\n",
      "Got 392 / 1000 correct (39.20)\n",
      "\n",
      "Iteration 400, loss = 2.2530\n",
      "Checking accuracy on validation set\n",
      "Got 316 / 1000 correct (31.60)\n",
      "\n",
      "Iteration 500, loss = 2.0894\n",
      "Checking accuracy on validation set\n",
      "Got 390 / 1000 correct (39.00)\n",
      "\n",
      "Iteration 600, loss = 1.7047\n",
      "Checking accuracy on validation set\n",
      "Got 424 / 1000 correct (42.40)\n",
      "\n",
      "Iteration 700, loss = 1.6730\n",
      "Checking accuracy on validation set\n",
      "Got 414 / 1000 correct (41.40)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "hidden_layer_size = 4000\n",
    "learning_rate = 1e-2\n",
    "model = TwoLayerFC(3*32*32, hidden_layer_size, 10)\n",
    "optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n",
    "\n",
    "train2(model, optimizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train a Three-Layer ConvNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 162,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration 0, loss = 3.0832\n",
      "Checking accuracy on validation set\n",
      "Got 110 / 1000 correct (11.00)\n",
      "\n",
      "Iteration 100, loss = 1.7194\n",
      "Checking accuracy on validation set\n",
      "Got 384 / 1000 correct (38.40)\n",
      "\n",
      "Iteration 200, loss = 1.7438\n",
      "Checking accuracy on validation set\n",
      "Got 417 / 1000 correct (41.70)\n",
      "\n",
      "Iteration 300, loss = 1.5973\n",
      "Checking accuracy on validation set\n",
      "Got 423 / 1000 correct (42.30)\n",
      "\n",
      "Iteration 400, loss = 1.3857\n",
      "Checking accuracy on validation set\n",
      "Got 461 / 1000 correct (46.10)\n",
      "\n",
      "Iteration 500, loss = 1.4872\n",
      "Checking accuracy on validation set\n",
      "Got 465 / 1000 correct (46.50)\n",
      "\n",
      "Iteration 600, loss = 1.4933\n",
      "Checking accuracy on validation set\n",
      "Got 478 / 1000 correct (47.80)\n",
      "\n",
      "Iteration 700, loss = 1.3025\n",
      "Checking accuracy on validation set\n",
      "Got 485 / 1000 correct (48.50)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "learning_rate = 3e-3\n",
    "channel_1 = 32\n",
    "channel_2 = 16\n",
    "\n",
    "params = {'in_channel': 3, 'channel_1': channel_1, 'channel_2': channel_1, 'num_classes': 10}\n",
    "model = ThreeLayerConvNet(params)\n",
    "optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n",
    "\n",
    "train2(model, optimizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PyTorch Sequential API\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Two-Layer Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 163,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration 0, loss = 2.3646\n",
      "Checking accuracy on validation set\n",
      "Got 161 / 1000 correct (16.10)\n",
      "\n",
      "Iteration 100, loss = 1.9941\n",
      "Checking accuracy on validation set\n",
      "Got 380 / 1000 correct (38.00)\n",
      "\n",
      "Iteration 200, loss = 1.8165\n",
      "Checking accuracy on validation set\n",
      "Got 398 / 1000 correct (39.80)\n",
      "\n",
      "Iteration 300, loss = 1.5029\n",
      "Checking accuracy on validation set\n",
      "Got 424 / 1000 correct (42.40)\n",
      "\n",
      "Iteration 400, loss = 1.8718\n",
      "Checking accuracy on validation set\n",
      "Got 407 / 1000 correct (40.70)\n",
      "\n",
      "Iteration 500, loss = 1.4071\n",
      "Checking accuracy on validation set\n",
      "Got 427 / 1000 correct (42.70)\n",
      "\n",
      "Iteration 600, loss = 1.6137\n",
      "Checking accuracy on validation set\n",
      "Got 437 / 1000 correct (43.70)\n",
      "\n",
      "Iteration 700, loss = 1.6795\n",
      "Checking accuracy on validation set\n",
      "Got 445 / 1000 correct (44.50)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "class Flatten(nn.Module):\n",
    "    def forward(self, x):\n",
    "        return flatten(x)\n",
    "    \n",
    "hidden_layer_size = 4000\n",
    "learning_rate = 1e-2\n",
    "\n",
    "model = nn.Sequential(\n",
    "    Flatten(),\n",
    "    nn.Linear(3*32*32, hidden_layer_size),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(hidden_layer_size, 10))\n",
    "\n",
    "optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)\n",
    "train2(model, optimizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Three-Layer ConvNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 165,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration 0, loss = 2.3005\n",
      "Checking accuracy on validation set\n",
      "Got 105 / 1000 correct (10.50)\n",
      "\n",
      "Iteration 100, loss = 1.6214\n",
      "Checking accuracy on validation set\n",
      "Got 423 / 1000 correct (42.30)\n",
      "\n",
      "Iteration 200, loss = 1.5319\n",
      "Checking accuracy on validation set\n",
      "Got 520 / 1000 correct (52.00)\n",
      "\n",
      "Iteration 300, loss = 1.5658\n",
      "Checking accuracy on validation set\n",
      "Got 501 / 1000 correct (50.10)\n",
      "\n",
      "Iteration 400, loss = 1.2787\n",
      "Checking accuracy on validation set\n",
      "Got 535 / 1000 correct (53.50)\n",
      "\n",
      "Iteration 500, loss = 1.3388\n",
      "Checking accuracy on validation set\n",
      "Got 546 / 1000 correct (54.60)\n",
      "\n",
      "Iteration 600, loss = 1.1717\n",
      "Checking accuracy on validation set\n",
      "Got 575 / 1000 correct (57.50)\n",
      "\n",
      "Iteration 700, loss = 1.3267\n",
      "Checking accuracy on validation set\n",
      "Got 532 / 1000 correct (53.20)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "channel_1 = 32\n",
    "channel_2 = 16\n",
    "learning_rate = 1e-2\n",
    "\n",
    "model = nn.Sequential(\n",
    "    nn.Conv2d(in_channels=3, out_channels=channel_1, kernel_size=(5, 5), stride=1, padding=(2, 2), bias=True),\n",
    "    nn.ReLU(),\n",
    "    nn.Conv2d(in_channels=channel_1, out_channels=channel_2, kernel_size=(3, 3), stride=1, padding=(1, 1), bias=True),\n",
    "    nn.ReLU(),\n",
    "    Flatten(),\n",
    "    nn.Linear(channel_2*32*32, 10)\n",
    ")\n",
    "optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)\n",
    "\n",
    "train2(model, optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
